# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Load Widerface dataset."""
import os
import copy
import json
import cv2
from mindvision.common.utils.class_factory import ClassFactory, ModuleType


@ClassFactory.register(ModuleType.DATASET)
class WiderfaceDataset:
    """RetinaFace Dataset for Widerface.

    Reads an annotation file in which every line starting with ``#`` names an
    image, and (in training mode) each following non-``#`` line is a list of
    whitespace-separated floats describing one face of that image.

    Image paths are built by replacing ``'label.txt'`` in ``label_path`` with
    ``'images/'`` and appending the header line minus its first two characters
    (the ``'# '`` marker).

    Args:
        label_path (str): Path to the annotation file.
        is_training (bool): When True, per-face labels are parsed, faces whose
            width (label[2]) or height (label[3]) is <= 0 are dropped, images
            left with no faces are dropped, and ``__getitem__`` returns
            ``(image, targets)``. When False, only image paths are collected and
            ``__getitem__`` returns ``(image, index)``.
        eval_imageid_file (str, optional): Evaluation mode only. File name,
            written as ``'./' + eval_imageid_file``, that receives a JSON object
            mapping each image path to its index. Nothing is written when it is
            None or empty.

    Raises:
        AssertionError: If an image named in the annotation file does not exist.
    """

    def __init__(self, label_path, is_training=True, eval_imageid_file=None):
        self.images_list = []
        self.labels_list = []
        self.image_anno_dict = {}
        self.is_training = is_training
        self.eval_imageid_file = eval_imageid_file
        if is_training:
            self._load_train_annotations(label_path)
        else:
            self._load_eval_images(label_path)
        self.num_samples = len(self.images_list)

    @staticmethod
    def _resolve_image_path(label_path, header_line):
        """Turn a ``'# <relative path>'`` header line into a checked image path."""
        path = label_path.replace('label.txt', 'images/') + header_line[2:]
        assert os.path.exists(path), 'image path is not exists.'
        return path

    def _load_eval_images(self, label_path):
        """Collect image paths only and optionally dump the path -> index map."""
        image_ids = {}
        with open(label_path, 'r') as label_file:
            for line in label_file:
                line = line.rstrip()
                if line.startswith('#'):
                    path = self._resolve_image_path(label_path, line)
                    # Index is the position the path is about to take in images_list.
                    image_ids[path] = len(self.images_list)
                    self.images_list.append(path)
        # Serialize once, after all images are known (the original re-dumped
        # the whole dict per image and crashed when the file had no images).
        if self.eval_imageid_file:
            with open('./' + self.eval_imageid_file, 'w') as json_file:
                json.dump(image_ids, json_file, indent=2)

    def _load_train_annotations(self, label_path):
        """Parse image paths and per-face labels, dropping degenerate boxes."""
        with open(label_path, 'r') as label_file:
            lines = label_file.readlines()

        image_paths = []
        labels_per_image = []
        current = []  # labels of the image header most recently seen
        seen_image = False
        for raw_line in lines:
            line = raw_line.rstrip()
            if not line:
                # Blank lines carry no data (float('') would raise).
                continue
            if line.startswith('#'):
                # Label lines before the first header stay attached to the
                # first image, matching the original parser.
                if seen_image:
                    labels_per_image.append(current)
                    current = []
                seen_image = True
                image_paths.append(self._resolve_image_path(label_path, line))
            else:
                current.append([float(value) for value in line.split()])
        # Flush the last image; guarded so an empty file yields an empty dataset.
        if seen_image:
            labels_per_image.append(current)

        for path, labels in zip(image_paths, labels_per_image):
            # Drop boxes whose width (index 2) or height (index 3) is not positive.
            kept = [label for label in labels if not (label[2] <= 0 or label[3] <= 0)]
            if kept:
                self.images_list.append(path)
                self.labels_list.append(kept)
                # Duplicate paths: the last occurrence wins, as before.
                self.image_anno_dict[path] = kept

    def __getitem__(self, index):
        """Return ``(image, index)`` in eval mode, ``(image, targets)`` in training.

        Each training target is ``[x_min, y_min, x_max, y_max]`` followed by
        label values 4-5, 7-8, 10-11, 13-14 and 16-17, and finally a flag that
        is -1 when ``label[4] < 0`` and 1 otherwise.
        """
        img_path = self.images_list[index]
        img = cv2.imread(img_path)
        if not self.is_training:
            return img, index
        out_target = []
        for label in self.image_anno_dict[img_path]:
            bbox = xywh2xyxy(label[0:4])
            # NOTE(review): presumably label[4] < 0 marks a face without landmark
            # annotations — confirm against the label file format.
            flag = -1 if label[4] < 0 else 1
            # Keep values in pairs, skipping indices 6, 9, 12, 15 (presumably
            # landmark (x, y) pairs with a per-point extra value — confirm).
            landmarks = label[4:6] + label[7:9] + label[10:12] + label[13:15] + label[16:18]
            out_target.append(bbox + landmarks + [flag])
        return img, out_target

    def __len__(self):
        return self.num_samples


def xywh2xyxy(bbox):
    """Convert a box from (x_min, y_min, width, height) to corner form.

    Args:
        bbox: Indexable whose items 0-3 are x_min, y_min, width and height.
            Any items past index 3 are ignored.

    Returns:
        list: ``[x_min, y_min, x_min + width, y_min + height]``.
    """
    left, top, width, height = (bbox[k] for k in range(4))
    return [left, top, left + width, top + height]
